{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processed question 0\n",
      "Processed question 1\n",
      "Processed question 2\n",
      "Processed question 3\n",
      "Processed question 4\n",
      "Processed question 5\n",
      "Processed question 6\n",
      "Processed question 7\n",
      "Processed question 8\n",
      "Processed question 9\n",
      "Processed question 10\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processed question 11\n",
      "Processed question 12\n",
      "Processed question 13\n",
      "Processed question 14\n",
      "Processed question 15\n",
      "Processed question 16\n",
      "Processed question 17\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import argparse\n",
    "from tqdm import tqdm\n",
    "\n",
    "save_dir = '<path_to_your_save_dir>'\n",
    "\n",
    "#Update these values according to what you had chosen for response generation\n",
    "question_num = 15\n",
    "step_size = 256\n",
    "trials = 3\n",
    "limit = 16384\n",
    "\n",
    "All_data = {}\n",
    "\n",
    "for question in range(0, question_num):\n",
    "    for chunk in range(step_size, limit + step_size, step_size):\n",
    "        file_name = save_dir + f\"/question_{question}_tokens_{chunk}.json\"\n",
    "        try:\n",
    "            with open(file_name, 'r') as f:\n",
    "                data = json.load(f)\n",
    "        except FileNotFoundError:\n",
    "            continue\n",
    "            raise FileNotFoundError(f\"File {file_name} not found\")\n",
    "        All_data[f\"question_{question}_tokens_{chunk}\"] = data\n",
    "    print(f\"Processed question {question}\")\n",
    "with open(save_dir + f\"/All_data.json\", 'w') as f:\n",
    "    json.dump(All_data, f)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "def obtaint_answer(s):\n",
    "    # Find first unpaired } by counting { and }\n",
    "    stack = []\n",
    "    for i, c in enumerate(s):\n",
    "        if c == '{':\n",
    "            stack.append(c)\n",
    "        elif c == '}':\n",
    "            if not stack:  # No matching { found\n",
    "                return s[:i]\n",
    "            stack.pop()\n",
    "    return \"\"\n",
    "\n",
    "\n",
    "def preprocess_data(all_data, problems):\n",
    "\n",
    "    problem_sets = {}\n",
    "    for i in range(problems):\n",
    "        for r in range(10):\n",
    "            problem_sets[f\"{i}_{r}\"] = {\n",
    "                'target' : None,\n",
    "                'is_finished' : [],\n",
    "                'score' : [],\n",
    "                'score_guide' : [],\n",
    "                'probe_answers' : [],\n",
    "                'probe_responses' : [],\n",
    "                'probe_prompts' : [],\n",
    "                'new_tokens' : [],\n",
    "                'prompts' : []\n",
    "            }\n",
    "\n",
    "    for i in tqdm(range(problems)):\n",
    "        for j in range(step_size, limit + step_size, step_size):\n",
    "            if True:\n",
    "                json_data = all_data[f\"question_{i }_tokens_{j}\"]\n",
    "                target = json_data.get(\"target\", None)\n",
    "                is_finished = json_data.get(\"is_finished\", None)\n",
    "                score = json_data.get(\"is_corrects_original\", None)\n",
    "                score_guide = json_data.get(\"is_corrects\", None)\n",
    "                guided_answers = [obtaint_answer(prompt) for prompt in json_data.get(\"probe_responses\", None)]\n",
    "                probe_responses = json_data.get(\"probe_responses\", None)\n",
    "                probe_prompts = json_data.get(\"probe_prompts\", None)\n",
    "                responses = json_data.get(\"new_tokens\", None)\n",
    "                prompts = json_data.get(\"prompts\", None)\n",
    "                #print(len(data))\n",
    "                for r in range(trials):\n",
    "                    problem_sets[f\"{i}_{r}\"]['target'] = target\n",
    "                    problem_sets[f\"{i}_{r}\"]['is_finished'].append(is_finished[r])\n",
    "                    problem_sets[f\"{i}_{r}\"]['score'].append(score[r])\n",
    "                    problem_sets[f\"{i}_{r}\"]['score_guide'].append(score_guide[r])\n",
    "                    problem_sets[f\"{i}_{r}\"]['probe_answers'].append(guided_answers[r])\n",
    "                    problem_sets[f\"{i}_{r}\"]['probe_responses'].append(probe_responses[r])\n",
    "                    problem_sets[f\"{i}_{r}\"]['probe_prompts'].append(probe_prompts[r])\n",
    "                    problem_sets[f\"{i}_{r}\"]['new_tokens'].append(responses[r])\n",
    "                    problem_sets[f\"{i}_{r}\"]['prompts'].append(prompts[r])\n",
    "\n",
    "    return problem_sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [],
   "source": [
    "math500_data = json.load(open(f'<path_to_your_save_dir>/All_data.json', 'r'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 18/18 [00:00<00:00, 1647.01it/s]\n"
     ]
    }
   ],
   "source": [
    "problem_set_7b_math500 = preprocess_data(math500_data, question_num)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "problem_set_outputs = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "def standard_accuracy(problem_sets, problems,max_tokens=limit):\n",
    "    all_scores = []\n",
    "    all_scores_guide = []\n",
    "    all_token_counts = []\n",
    "    tokens_per_problem = [[] for _ in range(problems)]\n",
    "    corrects_per_problem = [[] for _ in range(problems)]\n",
    "\n",
    "    score_guide_dict = defaultdict(lambda: 0)\n",
    "    score_dict = defaultdict(lambda: 0)\n",
    "    for run in range(trials):\n",
    "        token_counts = []\n",
    "        scores = []\n",
    "        scores_guide = []\n",
    "        for problem_id in range(problems):\n",
    "            ended = False\n",
    "            for step in range(0, max_tokens, step_size):\n",
    "                step_id = step // step_size                \n",
    "                finished = problem_sets[f\"{problem_id}_{run}\"]['is_finished'][step_id]\n",
    "                score = problem_sets[f\"{problem_id}_{run}\"]['score'][step_id]\n",
    "                score_guide = problem_sets[f\"{problem_id}_{run}\"]['score_guide'][step_id]\n",
    "\n",
    "                if finished:\n",
    "                    score_guide = score\n",
    "                    token_counts.append(step + step_size)\n",
    "                    ended = True\n",
    "                    break \n",
    "            if not ended:\n",
    "                token_counts.append(max_tokens)\n",
    "                score = problem_sets[f\"{problem_id}_{run}\"]['score'][max_tokens // step_size - 1]\n",
    "                score_guide = problem_sets[f\"{problem_id}_{run}\"]['score_guide'][max_tokens // step_size - 1]\n",
    "\n",
    "            for step in range(0, max_tokens, step_size):\n",
    "                #datas[problem_id][step] += int(score_guide)\n",
    "                step_id = step // step_size\n",
    "                score_dict[f\"{problem_id}_{step}\"] += problem_sets[f\"{problem_id}_{run}\"]['score'][step_id]\n",
    "                score_guide_dict[f\"{problem_id}_{step}\"] += problem_sets[f\"{problem_id}_{run}\"]['score_guide'][step_id]\n",
    "\n",
    "            #correct_dict[(problem_id, run)] = score\n",
    "            scores.append( score)\n",
    "            scores_guide.append(score_guide)\n",
    "            tokens_per_problem[problem_id].append(token_counts[-1])\n",
    "            corrects_per_problem[problem_id].append(score)\n",
    "        all_scores.extend(scores)\n",
    "        all_token_counts.extend(token_counts)\n",
    "        all_scores_guide.extend(scores_guide)\n",
    "    return all_scores, all_scores_guide, all_token_counts, score_dict, score_guide_dict, tokens_per_problem, corrects_per_problem\n",
    "\n",
    "def store_standard_outputs(problem_set_name, problem_set, problems, max_tokens=limit):\n",
    "    if problem_set_name not in problem_set_outputs:\n",
    "        problem_set_outputs[problem_set_name] = {}\n",
    "    scores = []\n",
    "    scores_guide = []\n",
    "    tokens_counts = []\n",
    "    score_dict = []\n",
    "    score_guide_dict = []\n",
    "    tokens_per_problem = []\n",
    "    corrects_per_problem = []\n",
    "    for tokens in tqdm(range(1024, max_tokens + 1024, 1024)): \n",
    "    # Store standard accuracy results\n",
    "        #print('tks: ', tokens)\n",
    "        standard_outputs = standard_accuracy(\n",
    "            problem_set,\n",
    "            problems, \n",
    "            tokens,\n",
    "        )\n",
    "        scores.append(standard_outputs[0])\n",
    "        scores_guide.append(standard_outputs[1])\n",
    "        tokens_counts.append(standard_outputs[2])\n",
    "        score_dict.append(standard_outputs[3])\n",
    "        score_guide_dict.append(standard_outputs[4])\n",
    "        tokens_per_problem.append(standard_outputs[5])\n",
    "        corrects_per_problem.append(standard_outputs[6])\n",
    "    problem_set_outputs[problem_set_name][f\"standard_0\"] = {\n",
    "        'scores': scores,\n",
    "        'scores_guide': scores_guide,\n",
    "        'token_counts': tokens_counts,\n",
    "        'score_dict': score_dict,\n",
    "        'score_guide_dict': score_guide_dict,\n",
    "        'tokens_per_problem': tokens_per_problem,\n",
    "        'corrects_per_problem': corrects_per_problem\n",
    "    }\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "from collections import Counter, defaultdict\n",
    "from typing import List\n",
    "import numpy as np\n",
    "from dynasor.core.evaluator import (\n",
    "    extract_answer,\n",
    "    strip_string,\n",
    "    math_equal,\n",
    "    extract_first_boxed_answer,\n",
    ")\n",
    "def entropy(Plist):\n",
    "    if len(Plist):\n",
    "        result = 0\n",
    "        for x in Plist:\n",
    "            result += (-x) * math.log(x, 2)\n",
    "        return result\n",
    "    else:\n",
    "        return 0\n",
    "\n",
    "def norm(Olist):\n",
    "    s = sum(Olist)\n",
    "    return [o / s for o in Olist]\n",
    "\n",
    "def count(Olist):\n",
    "    x_dict = defaultdict(lambda: 0.0)\n",
    "    for x in Olist:\n",
    "        x_dict[x] += 1\n",
    "    cc = [c for _,c in x_dict.items()]\n",
    "    #print(cc)\n",
    "    return cc\n",
    "\n",
    "def item_entropy(answers: List) -> float:\n",
    "    return entropy(norm(count(answers)))\n",
    "\n",
    "def list_equal(l):\n",
    "    equal_group = l[0]\n",
    "    for i in range(1, len(l)):\n",
    "        if not math_equal(l[i], equal_group):\n",
    "            return False\n",
    "    return True\n",
    "\n",
    "def count_not_empty(answers):\n",
    "    return sum(1 for answer in answers if answer != '')\n",
    "\n",
    "def majority_voting(answers):\n",
    "    equiv_classes = []\n",
    "    equiv_weights = []\n",
    "    max_vote = 0\n",
    "    for answer in answers:\n",
    "        weight = 1\n",
    "        flag = 0\n",
    "        for i, rep in enumerate(equiv_classes):\n",
    "            if math_equal(answer,rep):\n",
    "                flag = 1\n",
    "                equiv_weights[i] = equiv_weights[i]+weight\n",
    "                if equiv_weights[i] > max_vote:\n",
    "                    max_vote = equiv_weights[i]\n",
    "                    max_rep = answer\n",
    "                break\n",
    "        if flag:\n",
    "            continue\n",
    "        equiv_classes.append(answer)\n",
    "        equiv_weights.append(weight)\n",
    "        if max_vote == 0:\n",
    "            max_vote = weight\n",
    "            max_rep = answer\n",
    "    return max_rep\n",
    "\n",
    "uncertain_words = ['wait', 'hold', 'but', 'okay', 'no', 'hmm']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def earlyexit_accuracy(problem_sets, problems, jump = 4, continue_certain_bar = 2, warmup_steps = 0, max_tokens=limit, response_len=10, skip_wait=False):\n",
    "    all_scores = []\n",
    "    all_token_counts = []\n",
    "    tokens_per_problem_clip = [[] for _ in range(problems)]\n",
    "    corrects_per_problem_clip = [[] for _ in range(problems)]\n",
    "    score_dict = defaultdict(lambda: 0)\n",
    "\n",
    "    for run in range(trials):\n",
    "        token_counts = []\n",
    "        scores = []\n",
    "        for problem_id in range(problems):\n",
    "            ended = False\n",
    "            end_step = 100000\n",
    "            for step in range(0, max_tokens, step_size):\n",
    "                step_id = step // step_size                \n",
    "                finished = problem_sets[f\"{problem_id}_{run}\"]['is_finished'][step_id]\n",
    "                score = problem_sets[f\"{problem_id}_{run}\"]['score'][step_id]\n",
    "\n",
    "                if finished:\n",
    "                    token_counts.append(step + step_size)\n",
    "                    ended = True\n",
    "                    break \n",
    "\n",
    "                clip_responses = problem_sets[f\"{problem_id}_{run}\"]['probe_responses'][:step_id + 1:jump]\n",
    "                clip_answers = [obtaint_answer(prompt) for prompt in clip_responses]\n",
    "                #print(clip_answers, clip_responses[0])\n",
    "                certain_count = [not any(word in res.lower() for word in uncertain_words) for res in clip_responses[-continue_certain_bar:]]\n",
    "                \n",
    "            \n",
    "                if step >= warmup_steps and item_entropy(clip_answers[-continue_certain_bar:]) <= 0.01 and count_not_empty(clip_answers[-continue_certain_bar:]) == continue_certain_bar and (skip_wait or sum(certain_count) == continue_certain_bar) :\n",
    "                    token_counts.append(step + step_size)\n",
    "                    end_step = step\n",
    "                    ended = True\n",
    "                    score = math_equal(clip_answers[-1], problem_sets[f\"{problem_id}_{run}\"]['target'])\n",
    "                    break\n",
    "\n",
    "            if not ended:\n",
    "                token_counts.append(max_tokens)\n",
    "                score = problem_sets[f\"{problem_id}_{run}\"]['score_guide'][max_tokens // step_size - 1]\n",
    "\n",
    "            \n",
    "            for step in range(0, max_tokens, step_size):\n",
    "                #datas[problem_id][step] += int(score_guide)\n",
    "                step_id = step // step_size\n",
    "                if step >= end_step:\n",
    "                    score_dict[f\"{problem_id}_{step}\"] += score\n",
    "                else:\n",
    "                    score_dict[f\"{problem_id}_{step}\"] += problem_sets[f\"{problem_id}_{run}\"]['score_guide'][step_id]\n",
    "\n",
    "            #correct_dict[(problem_id, run)] = score\n",
    "            scores.append( score)\n",
    "\n",
    "            tokens_per_problem_clip[problem_id].append(token_counts[-1])\n",
    "            corrects_per_problem_clip[problem_id].append(scores[-1])\n",
    "\n",
    "        all_scores.extend(scores)\n",
    "        all_token_counts.extend(token_counts)\n",
    "    return all_scores, all_token_counts, score_dict, tokens_per_problem_clip, corrects_per_problem_clip\n",
    "\n",
    "def store_earlyexit_outputs(problem_set_name, problem_set, problems, configs, max_tokens=limit, skip_wait=False):\n",
    "    if problem_set_name not in problem_set_outputs:\n",
    "        problem_set_outputs[problem_set_name] = {}\n",
    "    for config in tqdm(configs):\n",
    "        key = f\"jump{config[0]}_bar{config[1]}_skipwait{skip_wait}\"\n",
    "        outputs = earlyexit_accuracy(\n",
    "            problem_set, \n",
    "            problems,\n",
    "            jump=config[0],\n",
    "            continue_certain_bar=config[1], \n",
    "            warmup_steps=0,\n",
    "            max_tokens=max_tokens,\n",
    "            response_len=10,\n",
    "            skip_wait=skip_wait\n",
    "        )\n",
    "        #print(outputs[2], outputs[3], outputs[4])\n",
    "        problem_set_outputs[problem_set_name][key] = {\n",
    "            'scores': outputs[0],\n",
    "            'token_counts': outputs[1], \n",
    "            'score_dict': outputs[2],\n",
    "            'tokens_per_problem': outputs[3],\n",
    "            'corrects_per_problem': outputs[4]\n",
    "        }\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "CONFIGS_1 = [\n",
    "    (1, 2),\n",
    "    (1, 3),\n",
    "    (1, 5),\n",
    "    (1, 8),\n",
    "    (1, 10),]\n",
    "\n",
    "CONFIGS_2 = [\n",
    "    (2, 2),\n",
    "    (2, 3),\n",
    "    (2, 5),\n",
    "    (2, 8),\n",
    "    (2, 10),]\n",
    "\n",
    "CONFIGS_4 = [\n",
    "    (4, 2),\n",
    "    (4, 3),\n",
    "    (4, 5),\n",
    "    (4, 8),\n",
    "    (4, 10),]\n",
    "\n",
    "CONFIGS_8 = [\n",
    "    (8, 2),\n",
    "    (8, 3),\n",
    "    (8, 5),\n",
    "    (8, 8),\n",
    "    (8, 10),    \n",
    "]\n",
    "\n",
    "CONFIGS_10 = [\n",
    "    (10, 2),\n",
    "    (10, 3),\n",
    "    (10, 5),\n",
    "    (10, 8),\n",
    "    (10, 10),\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 16/16 [00:00<00:00, 402.85it/s]\n",
      "100%|██████████| 5/5 [00:00<00:00,  9.37it/s]\n",
      "100%|██████████| 5/5 [00:00<00:00, 31.00it/s]\n",
      "100%|██████████| 5/5 [00:00<00:00, 47.38it/s]\n",
      "100%|██████████| 5/5 [00:00<00:00, 56.15it/s]\n",
      "100%|██████████| 5/5 [00:00<00:00, 101.11it/s]\n"
     ]
    }
   ],
   "source": [
    "store_standard_outputs('7b_math500', problem_set_7b_math500, question_num)\n",
    "\n",
    "\n",
    "store_earlyexit_outputs('7b_math500', problem_set_7b_math500, 10, CONFIGS_1)\n",
    "store_earlyexit_outputs('7b_math500', problem_set_7b_math500, 10, CONFIGS_2)\n",
    "store_earlyexit_outputs('7b_math500', problem_set_7b_math500, 10, CONFIGS_4)\n",
    "store_earlyexit_outputs('7b_math500', problem_set_7b_math500, 10, CONFIGS_8)\n",
    "store_earlyexit_outputs('7b_math500', problem_set_7b_math500, 10, CONFIGS_10)\n",
    "\n",
    "\n",
    "with open('problem_set_outputs.json', 'w') as f:\n",
    "    json.dump(problem_set_outputs, f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "# Create 1x2 subplot grid with shared legend\n",
    "fig = plt.figure(figsize=(10, 3))\n",
    "gs = fig.add_gridspec(1, 1, hspace=0.2, wspace=0.4)  # Increased wspace for larger gap between columns\n",
    "axs = gs.subplots()\n",
    "\n",
    "\n",
    "# Define colors and styles for each detection interval\n",
    "detection_styles = {\n",
    "    1: {'color': '#1f77b4', 'linestyle': '-', 'label': 'Ours (32)'},\n",
    "    2: {'color': '#2ca02c', 'linestyle': '-', 'label': 'Ours (64)'},\n",
    "    4: {'color': '#ff7f0e', 'linestyle': '-', 'label': 'Ours (128)'},\n",
    "    8: {'color': '#d62728', 'linestyle': '-', 'label': 'Ours (256)'},\n",
    "    10: {'color': '#9467bd', 'linestyle': '-', 'label': 'Ours (320)'}\n",
    "}\n",
    "\n",
    "# Create legend elements\n",
    "legend_lines = [plt.Line2D([0], [0], color='black', linestyle='--', label='Baseline')]\n",
    "for config_type in detection_styles:\n",
    "    style = detection_styles[config_type]\n",
    "    legend_lines.append(plt.Line2D([0], [0], **style))\n",
    "\n",
    "model = '7b'\n",
    "dataset = 'math500'\n",
    "# Plot baseline (unfinished) scores\n",
    "scores_unfinished = [np.mean(s) for s in problem_set_outputs[f'{model}_{dataset}']['standard_0']['scores']]\n",
    "tokens_unfinished = [np.mean(t) for t in problem_set_outputs[f'{model}_{dataset}']['standard_0']['token_counts']]\n",
    "baseline_line = plt.plot(tokens_unfinished, scores_unfinished, color='black', linestyle='--', alpha=0.8)[0]\n",
    "\n",
    "# Plot detection interval results\n",
    "configs = CONFIGS_1 + CONFIGS_2 + CONFIGS_4 + CONFIGS_8 + CONFIGS_10\n",
    "config_points = {1: [], 2: [], 4: [], 8: [], 10: []}\n",
    "\n",
    "for config in configs:\n",
    "    scores = problem_set_outputs[f'{model}_{dataset}'][f'jump{config[0]}_bar{config[1]}_skipwaitFalse']['scores']\n",
    "    token_counts = problem_set_outputs[f'{model}_{dataset}'][f'jump{config[0]}_bar{config[1]}_skipwaitFalse']['token_counts']\n",
    "    config_points[config[0]].append((np.mean(token_counts), np.mean(scores)))\n",
    "\n",
    "# Plot all lines\n",
    "for interval in config_points:\n",
    "    points = config_points[interval]\n",
    "    if len(points) > 1:\n",
    "        points.sort(key=lambda x: x[0])\n",
    "        tokens, scores = zip(*points)\n",
    "        style = detection_styles[interval]\n",
    "        line = plt.plot(tokens, scores, color=style['color'],\n",
    "                            linestyle=style['linestyle'], alpha=0.8)[0]\n",
    "\n",
    "plt.xlabel('Tokens', fontsize=10)\n",
    "plt.ylabel(f'{dataset.upper()}\\nAccuracy', fontsize=10)\n",
    "#axs[j].set_title(f'{model.upper()}')\n",
    "plt.grid(True, alpha=0.3)\n",
    "# Add shared legend at the bottom, moved higher up\n",
    "fig.legend(handles=legend_lines, loc='lower center', \n",
    "          ncol=6, frameon=False, bbox_to_anchor=(0.5, -0.12))\n",
    "plt.tight_layout()\n",
    "plt.savefig('token_deprivation_comparison_r1.pdf', bbox_inches='tight', dpi=300)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = []\n",
    "for problem_id in range(question_num):\n",
    "    for step in range(0, limit, step_size):\n",
    "        data.append({\n",
    "            \"Problem ID\": problem_id,\n",
    "            \"Token Budget\": step,\n",
    "            \"Score\": problem_set_outputs['7b_math500']['standard_0']['score_dict'][-1][f\"{problem_id}_{step}\"]\n",
    "        })\n",
    "\n",
    "\n",
    "data_guide = []\n",
    "for problem_id in range(question_num):\n",
    "    for step in range(0, limit, step_size):\n",
    "        data_guide.append({\n",
    "            \"Problem ID\": problem_id,\n",
    "            \"Token Budget\": step,\n",
    "            \"Score\": problem_set_outputs['7b_math500']['standard_0']['score_guide_dict'][-1][f\"{problem_id}_{step}\"]\n",
    "        })\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "import pandas as pd\n",
    "import json\n",
    "import os\n",
    "import glob\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Load data from JSON files\n",
    "# with open('data-cot.json', 'r') as f:\n",
    "#     data = json.load(f)\n",
    "\n",
    "# with open('data_guide-cot.json', 'r') as f:\n",
    "#     data_guide = json.load(f)\n",
    "\n",
    "# Convert pivot table to pandas DataFrame\n",
    "df = pd.DataFrame(data)\n",
    "df_guide = pd.DataFrame(data_guide)\n",
    "print (df.head())\n",
    "print (f\"You have {len(df)} rows\")\n",
    "pivot_table= df\n",
    "pivot_table = pd.pivot_table(df, values='Score', index=['Problem ID', 'Token Budget'], aggfunc='mean').reset_index() # This will aggregate\n",
    "pivot_table = pivot_table.pivot(index=\"Problem ID\", columns=\"Token Budget\", values=\"Score\") # This will turn into a proper pivot\n",
    "pivot_table.iloc[:5, :5]\n",
    "\n",
    "pivot_table_guide = pd.DataFrame(data_guide)\n",
    "pivot_table_guide = pd.pivot_table(df_guide, values='Score', index=['Problem ID', 'Token Budget'], aggfunc='mean').reset_index() # This will aggregate\n",
    "pivot_table_guide = pivot_table_guide.pivot(index=\"Problem ID\", columns=\"Token Budget\", values=\"Score\") # This will turn into a proper pivot\n",
    "\n",
    "# Calculate row sums for pivot_table and get the sorting order\n",
    "row_sums = pivot_table.sum(axis=1)\n",
    "sort_order = row_sums.argsort()\n",
    "\n",
    "# Sort both pivot tables using the same order from pivot_table\n",
    "pivot_table = pivot_table.iloc[sort_order]\n",
    "pivot_table_guide = pivot_table_guide.iloc[sort_order]\n",
    "\n",
    "# Create a custom colormap. Go to https://coolors.co/ and pick cool colors\n",
    "cmap = LinearSegmentedColormap.from_list(\"custom_cmap\", [\"#F0496E\", \"#EBB839\", \"#0CD79F\"])\n",
    "\n",
    "# Create figure with two subplots side by side\n",
    "plt.figure(figsize=(12.5, 5.5))\n",
    "\n",
    "# First subplot for pivot_table\n",
    "plt.subplot(1, 2, 1)\n",
    "sns.heatmap(\n",
    "    pivot_table,\n",
    "    fmt=\"g\",\n",
    "    cmap=cmap,\n",
    "    cbar_kws={'label': 'Score'}\n",
    ")\n",
    "plt.title('Model Output Scores\\nAcross Token Budgets')\n",
    "plt.xlabel('Token Budget')\n",
    "plt.ylabel('Problem ID')\n",
    "plt.xticks(rotation=45)\n",
    "plt.yticks(rotation=0)\n",
    "\n",
    "# Second subplot for pivot_table_guide\n",
    "plt.subplot(1, 2, 2)\n",
    "sns.heatmap(\n",
    "    pivot_table_guide,\n",
    "    fmt=\"g\", \n",
    "    cmap=cmap,\n",
    "    cbar_kws={'label': 'Score'}\n",
    ")\n",
    "plt.title('Prompt-In-The-Middle Output Scores\\nAcross Token Budgets')\n",
    "plt.xlabel('Token Budget')\n",
    "plt.ylabel('Problem ID')\n",
    "plt.xticks(rotation=45)\n",
    "plt.yticks(rotation=0)\n",
    "\n",
    "plt.suptitle('Pressure Testing R1-14B on HMMT_2025 Problem Solving Across Token Budgets (\"Token Deprivation\")')\n",
    "plt.tight_layout()\n",
    "\n",
    "# Show the plots\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sglang-reasoning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
